######################################################################
# Useful scripts for Ulsan Validation research
# Written by K. Bae @ Feb. 06, 2025, UNIST
# Python 3.12.5 on CentOS 7
######################################################################


def gbrs_info(stn_name, return_type=False, return_vaa=False):
    """Look up the location of a ground-based remote-sensing (GBRS) instrument
    station in Ulsan, optionally with its instrument type or VAA list.

    Args:
        stn_name (str): station name (UB_Ulsan, BIRA_Ulsan, P150)
        return_type (bool, optional): If True, also return the type of
            instrument (Pandora, MAX-DOAS). Takes precedence over
            ``return_vaa`` when both are True. Defaults to False.
        return_vaa (bool, optional): If True (and ``return_type`` is False),
            also return the station's VAA list (presumably viewing azimuth
            angles in degrees from north -- confirm). Defaults to False.

    Returns:
        stn_loc (list): [lon, lat] geolocation; returned alone when neither
            flag is set.
        inst_type (str): type of instrument (Pandora, MAX-DOAS); returned as
            ``(stn_loc, inst_type)`` when ``return_type`` is True.
        vaa_list (list): VAA values; returned as ``(stn_loc, vaa_list)`` when
            only ``return_vaa`` is True.

    Raises:
        ValueError: if ``stn_name`` is not one of the known stations.
    """
    # Built per call so callers always receive fresh lists they may mutate.
    stations = {
        'UB_Ulsan': ([129.305838, 35.493230], 'MAX-DOAS', [25, 90, 145, 345]),
        'BIRA_Ulsan': ([129.297356, 35.512979], 'MAX-DOAS', [35]),
        'P150': ([129.1896, 35.5745], 'Pandora', [67, 76.5, 105]),
    }
    try:
        stn_loc, instrument, vaa_list = stations[stn_name]
    except KeyError:
        raise ValueError(
            f"Unknown station {stn_name!r}; expected one of {sorted(stations)}"
        ) from None

    if return_type:
        return stn_loc, instrument
    if return_vaa:
        return stn_loc, vaa_list
    return stn_loc


def aqms_info(stn_name):
    """Return the AQMS site location and station code associated with a GBRS
    station name.

    UB_Ulsan and BIRA_Ulsan map to the same AQMS site.

    Args:
        stn_name (str): GBRS station name (P150, UB_Ulsan, BIRA_Ulsan).

    Returns:
        tuple: ``([lon, lat], stn_code)`` where ``stn_code`` is an int.

    Raises:
        ValueError: if ``stn_name`` has no associated AQMS site.
    """
    if stn_name == 'P150':
        return [129.22192, 35.5706], 238378
    if stn_name in ('UB_Ulsan', 'BIRA_Ulsan'):
        return [129.30592, 35.49311], 238373
    raise ValueError(
        f"No AQMS site known for station {stn_name!r}; "
        "expected one of ['BIRA_Ulsan', 'P150', 'UB_Ulsan']"
    )


def inst_type(stn):
    """Classify a station name by instrument type.

    Names starting with 'P' immediately followed by an ASCII digit
    (e.g. 'P150') are Pandora sites; every other name is MAX-DOAS.

    Args:
        stn (str): station name.

    Returns:
        str: "Pandora" or "MAX-DOAS".
    """
    import re

    is_pandora = re.match(r"P[0-9]", stn) is not None
    return "Pandora" if is_pandora else "MAX-DOAS"


def load_sat_ext(sat, stn, ext_type='normal', wind=False, era5_type='SL'):
    """Load a satellite extraction CSV for one station as a time-indexed DataFrame.

    Args:
        sat (str): satellite name. 'gems' (case-insensitive) uses the 'GEMS_V3'
            file-name tag and a space-separated timestamp format; any other
            name is upper-cased and parsed with an ISO 'T'-separated format.
        stn (str): station name (e.g. UB_Ulsan, BIRA_Ulsan, P150).
        ext_type (str, optional): 'normal' reads '*_ext' files, 'line' reads
            '*_line_ext' files. Defaults to 'normal'.
        wind (bool, optional): If True, read the ERA5-merged file from
            ``{data_path}/ERA5_{era5_type}_merge`` instead of the plain
            extraction directory. Defaults to False.
        era5_type (str, optional): ERA5 tag used in the merge directory name.
            Defaults to 'SL'.

    Returns:
        pandas.DataFrame: indexed by the parsed 'scantime' column, with an
        all-NaN 'RS_VCD' column added (plus all-NaN 'RS_TCD' and 'RS_VAA'
        columns for Pandora stations).

    Raises:
        ValueError: if ``ext_type`` is not 'normal' or 'line'.
    """
    # Validate before touching project config or the file system so a typo
    # fails with a clear message instead of an UnboundLocalError on _suffix.
    suffix_by_type = {'normal': 'ext', 'line': 'line_ext'}
    if ext_type not in suffix_by_type:
        raise ValueError(
            f"ext_type must be one of {list(suffix_by_type)}, got {ext_type!r}"
        )
    _suffix = suffix_by_type[ext_type]

    import pandas as pd
    import numpy as np
    from configure import data_path

    stn_type = inst_type(stn)

    if sat.lower() == 'gems':
        satname_path = 'GEMS_V3'
        date_format = '%Y-%m-%d %H:%M:%S'
    else:
        satname_path = sat.upper()
        date_format = '%Y-%m-%dT%H:%M:%S'

    if wind:
        ext_path = f'{data_path}/ERA5_{era5_type}_merge'
        # NOTE(review): the file name keeps a literal '_SL_merge' even when
        # era5_type != 'SL' -- confirm whether it should follow era5_type.
        ext_fname = f'{ext_path}/{stn}_{satname_path}_{_suffix}_SL_merge.csv'
    else:
        # Directory uses the bare upper-cased satellite name, the file name
        # uses satname_path (e.g. 'GEMS_ext/..._GEMS_V3_ext.csv').
        ext_fname = f'{data_path}/{sat.upper()}_{_suffix}/{stn}_{satname_path}_{_suffix}.csv'

    # '--' marks missing values in these CSVs.
    ext = pd.read_csv(ext_fname, sep=',', na_values='--')
    print(f'Read {ext_fname.split("/")[-1]}')
    # exact=False lets the format match a prefix of longer timestamp strings.
    ext['scantime'] = pd.to_datetime(ext['scantime'],
                                     format=date_format,
                                     exact=False)
    ext.set_index('scantime', inplace=True)

    # Placeholder columns initialised to NaN.
    if stn_type == 'Pandora':
        ext['RS_TCD'] = np.nan
        ext['RS_VAA'] = np.nan
    ext['RS_VCD'] = np.nan

    return ext


def weekday_weekend(datetime):
    """Label a timestamp as 'Weekend', 'Weekday' or 'None'.

    Delegates the day-of-week lookup to ``air_toolbox.util.find_dow``. Judging
    from the comparisons below, that helper returns upper-case three-letter day
    codes ('SAT', 'SUN', ...) or the string 'None' -- confirm against
    air_toolbox.

    Args:
        datetime: timestamp accepted by ``find_dow``.

    Returns:
        str: 'Weekend' for SAT/SUN, 'None' when find_dow returns 'None',
        otherwise 'Weekday'.
    """
    from air_toolbox.util import find_dow

    day_code = find_dow(datetime)
    if day_code == 'None':
        return 'None'
    return 'Weekend' if day_code in ('SAT', 'SUN') else 'Weekday'


def find_season(month):
    """Map a calendar month number to its meteorological season code.

    Args:
        month (int): month number, 1-12.

    Returns:
        str or None: 'DJF', 'MAM', 'JJA' or 'SON'; None for any value that is
        not one of the twelve months.
    """
    season_table = (
        ('DJF', (1, 2, 12)),
        ('MAM', (3, 4, 5)),
        ('JJA', (6, 7, 8)),
        ('SON', (9, 10, 11)),
    )
    for season, months in season_table:
        if month in months:
            return season
    return None


def high_low(vcd, threshold=1e16):
    """Label a column density as 'High' or 'Low' relative to a threshold.

    Args:
        vcd (float): value to classify.
        threshold (float, optional): values strictly greater than this are
            'High'. Defaults to 1e16, the original fixed cut-off.

    Returns:
        str: 'High' if ``vcd > threshold`` else 'Low'. NaN compares False,
        so NaN is labelled 'Low'.
    """
    return 'High' if vcd > threshold else 'Low'


def year(dt):
    """Return the campaign year that a timestamp falls into.

    A campaign year N runs from 1 August of year N (inclusive) to 1 August of
    year N+1 (exclusive). Only the 2021, 2022 and 2023 campaigns are defined.

    Args:
        dt (datetime.datetime): timestamp to classify.

    Returns:
        int: 2021, 2022 or 2023, or -999 when ``dt`` lies outside all three.
    """
    from datetime import datetime

    for campaign_start in (2021, 2022, 2023):
        if (dt >= datetime(campaign_start, 8, 1)) and (dt < datetime(campaign_start + 1, 8, 1)):
            return campaign_start
    return -999


class hourly_group:
    """Per-hour summary statistics of paired satellite / remote-sensing columns.

    Selects ``f'{sat_prefix}_{VAR}'``, ``f'RS_{VAR}'`` and ``'Hour'`` from
    ``df`` (VAR is ``sat_var`` upper-cased), groups by 'Hour', and stores
    each aggregate as a DataFrame indexed by hour.

    Attributes:
        groupby: the underlying ``DataFrameGroupBy`` keyed on 'Hour'.
        mean, std, count, median, max, min: per-hour aggregates.
        q1, q3: per-hour 25th and 75th percentiles.
    """

    def __init__(self, df, sat_var, sat_prefix='GEMS_V3'):
        """
        Args:
            df (pandas.DataFrame): must contain the three columns listed in
                the class docstring.
            sat_var (str): variable suffix, e.g. 'vcd' selects 'GEMS_V3_VCD'
                and 'RS_VCD'.
            sat_prefix (str, optional): satellite column prefix. Defaults to
                'GEMS_V3'.
        """
        var = sat_var.upper()
        _colname = [f'{sat_prefix}_{var}', f'RS_{var}', 'Hour']
        self.groupby = df[_colname].groupby('Hour')
        self.mean = self.groupby.mean()
        self.std = self.groupby.std()
        self.count = self.groupby.count()
        self.median = self.groupby.median()
        self.max = self.groupby.max()
        self.min = self.groupby.min()
        self.q1 = self.groupby.quantile(0.25)
        self.q3 = self.groupby.quantile(0.75)


def draw_line_from_point(ax, start_lon, start_lat, length_meters, angles_degrees, proj='epsg:3857', marker_size=50,
                         geodesic=False):
    """
    Draw dashed lines on the map starting from a point with given length and multiple angles.

    Parameters:
    -----------
    ax : matplotlib.axes.Axes
        The axes to draw on (assumed to be plotted in the ``proj`` CRS)
    start_lon, start_lat : float
        Starting point coordinates in EPSG:4326 (longitude, latitude)
    length_meters : float
        Length of each line. With ``geodesic=False`` (default) the offset is
        applied directly in ``proj`` units. For EPSG:3857 those are
        Web-Mercator metres, which overstate ground distance by about
        1/cos(latitude) (~1.23x near 35.5 N). With ``geodesic=True`` it is true
        ground distance in metres on the WGS84 ellipsoid.
    angles_degrees : list or real number
        List of angles from north in degrees (0 is north, 90 is east).
        A single real scalar (including numpy scalars) is wrapped in a list.
    proj : str
        Projection to use (default: 'epsg:3857')
    marker_size : int
        Size of the marker at the start point (default: 50)
    geodesic : bool
        If True, compute each end point along the WGS84 geodesic so the line
        spans ``length_meters`` of ground distance (default: False, which keeps
        the original projected-offset behaviour).
    """
    import math
    import numbers
    from pyproj import Geod, Transformer

    # numbers.Real also covers numpy integer/float scalars, which the original
    # (int, float) check let through to iteration and a TypeError.
    if isinstance(angles_degrees, numbers.Real):
        angles_degrees = [angles_degrees]

    # Transform start point into the plotting CRS
    transformer = Transformer.from_crs("epsg:4326", proj, always_xy=True)
    start_x, start_y = transformer.transform(start_lon, start_lat)

    # Plot marker at start point
    ax.scatter(start_x, start_y, c='k', s=marker_size, zorder=3)

    geod = Geod(ellps='WGS84') if geodesic else None

    # Draw lines for each angle
    for angle in angles_degrees:
        if geod is not None:
            # Forward geodesic: azimuth in degrees clockwise from north.
            end_lon, end_lat, _ = geod.fwd(start_lon, start_lat, angle, length_meters)
            end_x, end_y = transformer.transform(end_lon, end_lat)
        else:
            # Offset in projected units: sin -> east (x), cos -> north (y).
            angle_rad = math.radians(angle)
            end_x = start_x + length_meters * math.sin(angle_rad)
            end_y = start_y + length_meters * math.cos(angle_rad)

        ax.plot([start_x, end_x], [start_y, end_y], 'k--', linewidth=2)

def add_shapefile(ax, shapefile_path, edgecolor='k', facecolor='none', linewidth=0.5, zorder=1, proj='epsg:3857'):
    """
    Add a shapefile to the map.

    Parameters:
    -----------
    ax : matplotlib.axes.Axes
        The axes to draw on
    shapefile_path : str
        Path to the shapefile (.shp)
    edgecolor : str
        Color of the shapefile boundary (default: 'k' for black)
    facecolor : str
        Color of the shapefile fill (default: 'none' for transparent)
    linewidth : float
        Width of the boundary line (default: 0.5)
    zorder : int
        Drawing order (default: 1)
    proj : str
        CRS the geometries are reprojected to before plotting
        (default: 'epsg:3857', Web Mercator). Use the same value as
        ``draw_line_from_point(proj=...)`` so both layers line up.

    Notes:
    ------
    The shapefile must carry a CRS (e.g. a .prj sidecar); geopandas refuses to
    reproject geometries without one.
    """
    import geopandas as gpd
    # Read shapefile
    gdf = gpd.read_file(shapefile_path)

    # Reproject into the plotting CRS
    gdf = gdf.to_crs(proj)

    # Plot the shapefile
    gdf.plot(
        ax=ax,
        edgecolor=edgecolor,
        facecolor=facecolor,
        linewidth=linewidth,
        zorder=zorder
    )